% Save this .m file, together with all other .m and .mat files in the
% 'quantitative_model' directory on the Harvard Dataverse, to a local
% directory named "quantitative_model".
% Set the working directory to that local directory (replace LOCALPATH
% below with its location on your machine):
cd('LOCALPATH/quantitative_model');

% Initialize model parameters by running the assign_parameters script
assign_parameters;

%%% Split Home's goods between the rust belt (RB) and the sun belt (SB).
%%% The first nRBgoods goods (columns) are rust-belt goods; the rest are
%%% sun-belt goods.
nRBgoods = floor(ngoods*lambda); % Number of goods in the rust belt in Home
nSBgoods = ngoods - nRBgoods; % Number of goods in the sun belt in Home
% Require at least one good in each region. Testing each count separately
% (instead of nRBgoods*nSBgoods == 0) also rejects lambda > 1 or lambda < 0,
% where one of the counts becomes negative and the product is non-zero.
if nRBgoods < 1 || nSBgoods < 1
    error('Home country has either no rust belt or no sun belt goods. There must be at least 1 good in each region. Adjust ngoods or lambda.');
end
zeta_initial = zeta_F_RB; % Initial value of zeta, to be used when printing moments
%%% structure holding parameters to be passed into functions
Par.ngoods = ngoods; % Number of goods in each economy
Par.nRBgoods = nRBgoods; % Number of goods in the rust belt in Home
Par.nSBgoods = nSBgoods; % Number of goods in the sun belt in Home
Par.sigma = sigma; % Inter-sectoral elasticity of substitution
Par.rho = rho; % Intra-sector elasticity of substitution between H and F
Par.tau = tau0; % Iceberg trade cost (initial value; updated every period)
Par.kappa = kappa; % Fraction of output lost
% Initial value for the investment rate in this counterfactual experiment:
s = 0.41; % Fraction of retained profits reinvested
Par.s = s;
% The investment rate s increases by this amount each period:
s_change = 0.00282;
Par.sigmazH = sigmazH; % Dispersion parameter for distribution of zH
Par.LS = LS; % Labor supply in each country

% Construct a vector of phi, with one value for each good
% (phi_R for rust-belt goods, phi_S for sun-belt goods)
vphi = [ones(1,nRBgoods).*phi_R ones(1,nSBgoods).*phi_S];

%%% Target strike rate per good: the first nRBgoods columns correspond to
%%% rust belt goods in Home; the other Home goods are in the sun belt.
target_strike_rate = [ones(1,nRBgoods).*target_strike_rate_RB ones(1,nSBgoods).*target_strike_rate_SB];

%%% Fix the random number seed.
%%% The code (in particular the initial guesses) is tuned for this seed;
%%% with any other seed the initial guess is off and the code runs far slower.
rng(2022);

%%% Draw initial productivity levels for each good as a 2-by-ngoods matrix:
%%% row 1 = Home goods, row 2 = Foreign goods. Entry (r,j) is lognormal with
%%% parameters zbar0(r,j) and sigmaz0(r,j) of the underlying normal.
zbar0   = repmat([zbarH0; zbarF0], [1, ngoods]);
sigmaz0 = repmat([sigmazH0; sigmazF0], [1, ngoods]);
mz = lognrnd(zbar0, sigmaz0, [2, ngoods]);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Initialize time series matrices of results
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%% Aggregate variables: one value per period, stored as 1 x 1 x nPeriods
%%%   wF_t                     foreign wage
%%%   S1_t, S2_t               S1 and S2 terms in the aggregate solution
%%%   strike_rate_RB/SB/overall_t   ex post strike rate (rust belt, sun belt, Home)
%%%   strike_prob_RB/SB/overall_t   average strike probability (rust belt, sun belt, Home)
%%%   P_H_t, P_F_t             price index for final goods (home, foreign)
%%%   Y_H_t, Y_F_t             final output (home, foreign)
[wF_t, S1_t, S2_t, ...
    strike_rate_RB_t, strike_rate_SB_t, strike_rate_overall_t, ...
    strike_prob_RB_t, strike_prob_SB_t, strike_prob_overall_t, ...
    P_H_t, P_F_t, Y_H_t, Y_F_t] = deal(zeros(1,1,nPeriods));

%%% Good-level series: 1 x ngoods x nPeriods (the row is Home, or Foreign for pi_F_t)
%%%   strike_t     actual strike outcomes
%%%   request_t    union requests
%%%   rent_t       union rents
%%%   pi_H_t       profits from goods produced in Home
%%%   pi_F_t       profits from goods produced in Foreign
%%%   vztildeH_t   eqm cutoff values for rejecting the union offer
[strike_t, request_t, rent_t, pi_H_t, pi_F_t, vztildeH_t] = deal(zeros(1,ngoods,nPeriods));

%%% Good-country series: 2 x ngoods x nPeriods (row 1 = home, row 2 = foreign)
%%%   z_t          productivity by good, country
%%%   sector_y_t   output by good, country
%%%   i_t          investment by good, country
[z_t, sector_y_t, i_t] = deal(zeros(2,ngoods,nPeriods));

%%% Origin-destination series (labor l_t, output y_t, prices p_t):
%%% 4 x ngoods x nPeriods, rows:
%%%   1 = home produced, home sales      2 = home produced, foreign sales
%%%   3 = foreign produced, home sales   4 = foreign produced, foreign sales
[l_t, y_t, p_t] = deal(zeros(4,ngoods,nPeriods));

%%% Initial guess for the aggregate variables (updated each period), in logs.
%%% Element 1 = S, element 2 = S*, element 3 = w_F (see notes).
%%% Hard-coded alternative: initial_guess_agg = log([0.001699 0.001756 0.993441]);
load('agg_1950_solution.mat');   % expected to define agg_sol
initial_guess_agg = log(agg_sol);

%%% Initial guess for labor (updated each period), in logs.
%%% rows: home firm/home sales (1), home firm/foreign sales (2),
%%%       foreign firm/home sales (3), foreign firm/foreign sales (4)
%%% cols: sectors
%%% Hard-coded alternative:
%%%   initial_guess_labor = repmat(log([0.9864;0.0136;0.0129;0.9871]./ngoods),1,ngoods);
load('labor_1950_solution.mat'); % expected to define sol_mlabor
initial_guess_labor = log(sol_mlabor);

%%% Initial guess for cutoff productivity (updated each period), in logs.
%%% 1 row (home firm), one column per sector. Alternative guess:
%%%   log([logninv(target_strike_rate_RB,log(mz(1,1:nRBgoods))-sigmazH^2/2,sigmazH) ...
%%%        logninv(target_strike_rate_SB,log(mz(1,nRBgoods+1:end))-sigmazH^2/2,sigmazH)])
load('ztildeH_1950_solution.mat'); % expected to define vztildeH
initial_guess_ztildeH = log(vztildeH);

%%% Load the full solution path from the last run; wF_t_sol, l_t_sol,
%%% ztildeH_t_sol, S1_t_sol and S2_t_sol build each period's guess below.
load('wF_t_solution.mat');
load('l_t_solution.mat');
load('ztildeH_t_solution.mat');
load('S1_t_solution.mat');
load('S2_t_solution.mat');

%%% Initial guess matrix passed into the aggregate residual function:
%%% rows 1-4 = labor allocation guess, row 5 = ztilde guess; cols = sectors
mguess = [initial_guess_labor; initial_guess_ztildeH];

%%% Accumulated solve time across periods (seconds)
run_time = 0;

        
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Solve Model in All Periods
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% For each period t this loop:
%   1. draws ex-post Home productivities and solves the aggregate system
%      (S1, S2, foreign wage) with fsolve;
%   2. re-solves each good's labor allocation and strike cutoff at the
%      equilibrium aggregates (parfor over goods), then calls
%      compute_eq_values;
%   3. records the period's results and prints diagnostics;
%   4. updates productivity (from investment, plus the foreign rust-belt
%      boost zeta_F_RB) and the investment rate s;
%   5. builds initial guesses for period t+1 and saves solutions in
%      period 1 and period period_data_end;
%   6. applies the scheduled vphi / zeta_F_RB changes (data run only) and
%      updates the trade cost tau.

disp('Rust Belt Model');

%%% Loop-invariant settings
% fsolve options shared by the aggregate and firm-level solves
options = optimoptions(@fsolve,'Display','off','OptimalityTolerance',1e-5);
% Trade-cost path tau0*delta_tau^(t-1); equals Par.tau during period t
vtau = tau0.*delta_tau.^(0:1:nPeriods-1);
% Periods t < period_to_switch take next period's guesses from the loaded
% previous run; later periods scale up the current solution instead
period_to_switch = 40;
% Last period of the data-matching run: the full solution path is saved in
% this period, and the scheduled vphi / zeta_F_RB changes only apply when
% nPeriods equals it
period_data_end = 51;

for t=1:nPeriods

    %%% record start time for this period
    tstart = tic;
    disp(' ');
    disp(['Solving period ', num2str(t),' of ',num2str(nPeriods)]);

    %%% Solve the aggregate system for this period
    % draw actual (ex-post) Home productivity realizations around mz(1,:)
    zH_expost = lognrnd(log(mz(1,:))-sigmazH^2/2,sigmazH,1,ngoods);
    objfun = @(logagg) residual_for_agg_var(Par,vphi,mz,logagg,zH_expost,mguess);
    [agg_sol,err_agg] = fsolve(objfun,initial_guess_agg,options);
    agg_sol = exp(agg_sol);

    %%% Use eqm aggregate solution to now compute eqm good-country outcomes
    %%% Initialize matrices to store solutions
    mlabor = zeros(4,ngoods);
    err_l = mlabor;
    vztildeH = zeros(1,ngoods);
    err_ztildeH = vztildeH;

    %%% Solve each firm's problem once more at the eqm aggregate solution
    parfor igood = 1:ngoods
        lguess = initial_guess_labor(:,igood);
        zguess = initial_guess_ztildeH(1,igood);
        phi = vphi(1,igood);
        vz = mz(:,igood);
        objfun_l = @(vloglabor) residual_for_labor(Par,phi,vz,agg_sol,vloglabor,zguess);
        [l_sol,err] = fsolve(objfun_l,lguess,options);
        mlabor(:,igood) = exp(l_sol);
        err_l(:,igood) = err;
        objfun_z = @(logztildeH) residual_for_ztilde(Par,phi,vz,agg_sol,exp(l_sol),logztildeH);
        [z_sol,err] = fsolve(objfun_z,zguess,options);
        vztildeH(igood) = exp(z_sol);
        err_ztildeH(igood) = err;
    end
    %Sol = compute_eq_values_union(Par,[zH_expost;mz(2,:)],agg_sol_union,ml,vztildeH);
    Sol = compute_eq_values(Par,vphi,mz,agg_sol,mlabor,vztildeH,zH_expost);

    disp(['  Aggregate Solution: ', num2str(agg_sol(1,1),'%.6f'),' ',num2str(agg_sol(1,2),'%.6f'),' ',num2str(agg_sol(1,3),'%.6f')  ]);
    disp(['  Max error for aggregate conditions = ', num2str(max(abs(err_agg),[],'all'),'%.8f')]);
    disp(['  Max error for labor optimality condition = ', num2str(max(abs(err_l),[],'all'),'%.8f')]);
    disp(['  Max error for strike cutoff optimality condition = ', num2str(max(abs(err_ztildeH),[],'all'),'%.8f')]);

    %%% Record equilibrium objects for this period
    wF_t(1,1,t) = Sol.wF; % foreign wage
    S1_t(1,1,t) = agg_sol(1,1); % S1 term
    S2_t(1,1,t) = agg_sol(1,2); % S2 term
    vztildeH_t(1,:,t) = vztildeH; % ztildeH solutions
    strike_rate_RB_t(1,1,t) = Sol.strikerate_RB; % ex post strike rate in rust belt
    strike_rate_SB_t(1,1,t) = Sol.strikerate_SB; % ex post strike rate in sun belt
    strike_rate_overall_t(1,1,t) = Sol.strikerate_overall; % ex post strike rate in Home country
    strike_prob_RB_t(1,1,t) = Sol.strikeprob_RB; % probability of strike in rust belt
    strike_prob_SB_t(1,1,t) = Sol.strikeprob_SB; % probability of strike in sun belt
    strike_prob_overall_t(1,1,t) = Sol.strikeprob_overall; % probability of strike in Home country
    P_H_t(1,1,t) = Sol.price_index(1,1); % price index for final goods, home
    P_F_t(1,1,t) = Sol.price_index(2,1); % price index for final goods, foreign
    Y_H_t(1,1,t) = Sol.final_output(1,1); % final output at home
    Y_F_t(1,1,t) = Sol.final_output(2,1); % final output foreign

    %%% Firm-level strike, request and rent series
    %%% 1 row corresponds to home; columns: firms; heights: time period
    strike_t(1,:,t) = Sol.vstrike;  % actual strike outcomes
    request_t(1,:,t) = Sol.request;  % union requests
    rent_t(1,:,t) = Sol.rent; % rents paid to unions
    pi_H_t(1,:,t) = Sol.profitsH; % profits by H firms
    pi_F_t(1,:,t) = Sol.profitsF; % profits by F firms

    %%% Firm-country level outcomes
    %%% rows: home (row 1) vs foreign (row 2); columns: firms; heights: time period
    z_t(:,:,t) = mz; % productivity by firm, country
    sector_y_t(:,:,t) = Sol.sector_output; % sector output by firm, country
    i_t(:,:,t) = max(Sol.minvestment,0.00001); % floor: some approximations come out slightly below zero

    %%% Firm-country-destination level outcomes (labor, output, prices)
    %%% rows: home firm/home sales (1), home firm/foreign sales (2),
    %%%       foreign firm/home sales (3), foreign firm/foreign sales (4)
    %%% columns: firms; heights: time periods
    l_t(:,:,t) = Sol.mlabor;
    y_t(:,:,t) = Sol.moutput;
    p_t(:,:,t) = Sol.mprice;

    %%% Print some statistics to the screen
    % These series span all periods (future periods are still zero); only
    % entry t is printed, but the full series remain in the workspace.
    py_t = p_t.*y_t;
    RB_VA = shiftdim(sum(py_t(1,1:nRBgoods,:)+py_t(2,1:nRBgoods,:),2))'; %%% total value of RB production
    imports_foreign_RB = vtau.*shiftdim(sum(py_t(3,1:nRBgoods,:),2))'; %% total value of foreign RB sales at home
    RB_import_share = imports_foreign_RB./RB_VA;
    Agg_VA = shiftdim(sum(py_t(1,:,:)+py_t(2,:,:),2))'; %%% total value added
    GDP_1 = shiftdim(Y_H_t).*shiftdim(P_H_t); % nPeriods x 1: P_H * Y_H per period
    imports_agg = vtau.*shiftdim(sum(py_t(3,:,:),2))'; %% total value of foreign sales at home (all goods)
    %%Agg_import_share = imports_agg./Agg_VA;
    Agg_import_share = imports_agg./GDP_1';
    disp(['  Aggregate Import Share = ', num2str(100*Agg_import_share(1,t),'%.1f'),'%']);
    % this period's imports (previously the sum over periods 1..t was shown)
    disp(['  Aggregate Imports      = ', num2str(imports_agg(1,t),'%.8f')]);
    disp(['  Rust Belt Import Share = ', num2str(100*RB_import_share(1,t),'%.1f'),'%']);
    % total RB labor; NOTE(review): a share only if total labor is normalized to 1
    RB_emp_share = shiftdim(sum(l_t(1,1:nRBgoods,t)+l_t(2,1:nRBgoods,t)));
    labor_share = (1+sum(Sol.rent,2)-kappa*sum(strike_t(1,:,t).*l_t(1:2,:,t),[1 2]))/GDP_1(t,1);
    disp(['  Labor Share of GDP = ',num2str(100*labor_share,'%.1f'),'%']);
    disp(['  Employment share in Rust Belt = ',num2str(100*RB_emp_share,'%.1f'),'%']);
    disp(['  Strike probability in Rust Belt = ', num2str(100*strike_prob_RB_t(1,1,t),'%.1f'),'%']);
    disp(['  Strike rate in Rust Belt = ', num2str(100*strike_rate_RB_t(1,1,t),'%.1f'),'%']);
    RB_labor = shiftdim(sum(l_t(1,1:nRBgoods,t)+l_t(2,1:nRBgoods,t)))'; %%% wages to RB
    % NOTE(review): the strike-loss term sums strike_t over ALL goods (not only
    % 1:nRBgoods) and is not labor-weighted, unlike labor_share above --
    % confirm this is the intended rust-belt wage-premium definition.
    wage_prem_RB = (shiftdim(sum(rent_t(1,1:nRBgoods,t)))-kappa*sum(strike_t(1,:,t))/nRBgoods)./RB_labor'; %%% rents / wages in RB
    disp(['  Rust Belt wage premium = ', num2str(100*wage_prem_RB,'%.1f'),'%']);
    var_log_emp = var(log(l_t(1,:,t)+l_t(2,:,t)));
    disp(['  Variance of log employment = ', num2str(var_log_emp,'%.2f')]);
    I = sum(i_t(1,:,t));
    %%% alternative without trade costs: GDP = sum(p_t(1,:,t).*y_t(1,:,t) + p_t(3,:,t).*y_t(3,:,t))
    GDP = sum(p_t(1,:,t).*y_t(1,:,t) + Par.tau.*p_t(3,:,t).*y_t(3,:,t));
    disp(['  Investment/VA ratio = ', num2str(100*I/GDP,'%.1f'),'%']);

    %%% Print growth rates if after period 1
    if t>1
        avg_firm_growth_rate = 100*(mean(z_t(1,:,t)./z_t(1,:,t-1),'all') - 1);
        avg_RB_firm_growth_rate = 100*(mean(z_t(1,1:Par.nRBgoods,t)./z_t(1,1:Par.nRBgoods,t-1),'all') - 1);
        avg_ROC_firm_growth_rate = 100*(mean(z_t(1,Par.nRBgoods+1:Par.ngoods,t)./z_t(1,Par.nRBgoods+1:Par.ngoods,t-1),'all') - 1);
        growth_differential = avg_firm_growth_rate - avg_RB_firm_growth_rate;
        disp(['  Average U.S. firm productivity growth rate = ', num2str(avg_firm_growth_rate,'%.3f'),'%']);
        disp(['  Average RB productivity growth rate = ', num2str(avg_RB_firm_growth_rate,'%.3f'),'%']);
        disp(['  Average ROC productivity growth rate = ', num2str(avg_ROC_firm_growth_rate,'%.3f'),'%']);
        disp(['  U.S. - RB productivity growth differential = ', num2str(growth_differential,'%.3f'),'%']);
    end

    %%% Update productivity levels
    m_z_F_boost = ones(2,ngoods); % matrix with foreign productivity boosts
    m_z_F_boost(2,1:nRBgoods) = zeta_F_RB; % only foreign rust-belt goods get the boost
    mz_sigma_H = mz(1,:).^(sigma-1);
    mz_sigma_F = mz(2,:).^(sigma-1);
    D_Z_H = mean(mz_sigma_H);
    D_Z_F = mean(mz_sigma_F);
    mx_H = max(((i_t(1,:,t).*D_Z_H./mz_sigma_H)./alpha).^(1/gamma),0.00001); % floor: some approximations slightly below zero
    mx_F = max(((i_t(2,:,t).*D_Z_F./mz_sigma_F)./alpha).^(1/gamma),0.00001);
    mx = [mx_H; mx_F];
    mz = mz.*(1+mx).*m_z_F_boost;

    %%% Update the investment (savings) rate along its trend
    s = s + s_change;
    Par.s = s;

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%% Update initial guesses for period t+1.
    %%% Very important for speed: with sensible alternatives that do not
    %%% load the previous run, runtime can be 5 times longer!
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    if t<period_to_switch
        % guess based on loading the solution from the previous run
        initial_guess_labor = log(l_t_sol(:,:,t+1));
        initial_guess_ztildeH = log(ztildeH_t_sol(:,:,t+1));
        initial_guess_agg = [log(S1_t_sol(t+1,1)) log(S2_t_sol(t+1,1)) log(wF_t_sol(t+1,1))];
    else
        % guess by scaling up the current solution
        if t<period_data_end
            % previous run's labor solution -- this seems to work best
            initial_guess_labor = log(l_t_sol(:,:,t+1));
        else
            % beyond the saved path: start from this period's labor solution
            initial_guess_labor = log(mlabor);
        end
        % assume ztilde grows at its average rate over the last period
        % (this works better than the previous run's solution)
        ztildeH_scale_up_factor = mean(vztildeH_t(1,:,t)./vztildeH_t(1,:,t-1));
        initial_guess_ztildeH = log(vztildeH_t(1,:,t)*ztildeH_scale_up_factor);
        % assume the foreign wage and S1 evolve at the same rate as this period
        wage_F_guess_scale_up_factor = wF_t(1,1,t)/wF_t(1,1,t-1);
        S1_guess_scale_up_factor = S1_t(1,1,t)/S1_t(1,1,t-1);
        S2_guess_scale_up_factor = 1.0000; % S2 assumed constant
        initial_guess_agg = log(agg_sol.*[S1_guess_scale_up_factor S2_guess_scale_up_factor wage_F_guess_scale_up_factor]);
    end
    mguess = [initial_guess_labor;initial_guess_ztildeH];

    %%% Initial guesses if not loading a previous solution (kept for reference)
    %%%initial_guess_agg = log(agg_sol); %just guess yesterday's solution
    %%%wage_F_guess_scale_up_factor = 1.0427; %guess for how foreign wage evolves
    %%%S1_guess_scale_up_factor = 0.9755; %guess for how S1 terms evolves
    %%%S2_guess_scale_up_factor = 1.0000; %guess for how S2 terms evolves
    %%%initial_guess_agg = log(exp(initial_guess_agg).*[S1_guess_scale_up_factor S2_guess_scale_up_factor wage_F_guess_scale_up_factor]);
    %%%initial_guess_ztildeH = log(logninv(logncdf(vztildeH,log(z_t(1,:,t))-sigmazH^2/2,sigmazH),log(mz(1,:))-sigmazH^2/2,sigmazH));
    %%%initial_guess_labor = log(mlabor);

    %%% Save the first-period solution, and the whole path in period_data_end
    if t==1
        save ('agg_1950_solution.mat','agg_sol');
        save ('ztildeH_1950_solution.mat','vztildeH');
        sol_mlabor = Sol.mlabor;
        save('labor_1950_solution.mat','sol_mlabor');
        disp('  Saved solution to first period');
    elseif t==period_data_end
        wF_t_sol = shiftdim(wF_t);
        save ('wF_t_solution.mat','wF_t_sol');
        ztildeH_t_sol = vztildeH_t;
        save ('ztildeH_t_solution.mat','ztildeH_t_sol');
        S1_t_sol = shiftdim(S1_t);
        save ('S1_t_solution.mat','S1_t_sol');
        S2_t_sol = shiftdim(S2_t);
        save ('S2_t_solution.mat','S2_t_sol');
        l_t_sol = l_t;
        save ('l_t_solution.mat','l_t_sol');
        disp(['  Saved solution to periods 1 through ', num2str(period_data_end)]);
    end

    %%% Scheduled parameter changes, applied ONLY when nPeriods equals
    %%% period_data_end (51), i.e. when trying to match the data:
    %%%  - lower the strike parameter vphi at the end of each of periods
    %%%    30-35 (compounded), so it first affects the solve in period 31;
    %%%  - after period 36's productivity update, stop the extra exogenous
    %%%    productivity growth of foreign rust-belt goods.
    if nPeriods==period_data_end && t>=30 && t<=35
        vphi = vphi*0.712;
    elseif nPeriods==period_data_end && t==36
        zeta_F_RB = 1.00; % exogenous productivity growth of foreign RB
    end

    %%% update tau for the next period
    Par.tau = Par.tau*delta_tau;

    %%% record and display end time for this period
    telapsed = toc(tstart);
    disp(['  Solved in ', num2str(telapsed,'%.1f'),' seconds ']);
    run_time = run_time + telapsed;
    disp(['  Total run time so far is ', num2str(run_time,'%.1f'),' seconds ']);
    disp(['  Average run time per period is ', num2str(run_time/t,'%.1f'),' seconds ']);

end

% Save the entire workspace (all results) to disk
save('solution_model_CF_s_trend.mat');

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Make plots, print moments + save
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

load_data;     % raw data for the model-vs-data comparison
makePlots;     % graphs of model vs data
printMoments;  % model vs data moments

%%% Change in total Home rust-belt labor (rows 1-2 of l_t, RB goods)
%%% between period 1 and period t, reported in percentage points
RB_emp_first = sum(l_t(1,1:nRBgoods,1) + l_t(2,1:nRBgoods,1));
RB_emp_last  = sum(l_t(1,1:nRBgoods,t) + l_t(2,1:nRBgoods,t));
disp(['Change in Rust Belt`s Employment Share is ',num2str((RB_emp_last - RB_emp_first)*100,'%.1f'),' percentage points'])